{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 序列逆置 （加注意力的seq2seq）\n",
    "使用attentive sequence to sequence 模型将一个字符串序列逆置。例如 `OIMESIQFIQ` 逆置成 `QIFQISEMIO`(下图来自网络，是一个加attentino的sequence to sequence 模型示意图)\n",
    "![attentive seq2seq](./seq2seq-attn.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import layers, optimizers, datasets\n",
    "import os,sys,tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 玩具序列数据生成\n",
    "生成只包含[A-Z]的字符串，并且将encoder输入以及decoder输入以及decoder输出准备好（转成index）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(['GLSUIOUHPB', 'QLXIKLZQBX'], <tf.Tensor: id=25627, shape=(2, 10), dtype=int32, numpy=\n",
      "array([[ 7, 12, 19, 21,  9, 15, 21,  8, 16,  2],\n",
      "       [17, 12, 24,  9, 11, 12, 26, 17,  2, 24]], dtype=int32)>, <tf.Tensor: id=25628, shape=(2, 10), dtype=int32, numpy=\n",
      "array([[ 0,  2, 16,  8, 21, 15,  9, 21, 19, 12],\n",
      "       [ 0, 24,  2, 17, 26, 12, 11,  9, 24, 12]], dtype=int32)>, <tf.Tensor: id=25629, shape=(2, 10), dtype=int32, numpy=\n",
      "array([[ 2, 16,  8, 21, 15,  9, 21, 19, 12,  7],\n",
      "       [24,  2, 17, 26, 12, 11,  9, 24, 12, 17]], dtype=int32)>)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import string\n",
    "\n",
    "def randomString(stringLength):\n",
    "    \"\"\"Generate a random string with the combination of lowercase and uppercase letters \"\"\"\n",
    "\n",
    "    letters = string.ascii_uppercase\n",
    "    return ''.join(random.choice(letters) for i in range(stringLength))\n",
    "\n",
    "def get_batch(batch_size, length):\n",
    "    batched_examples = [randomString(length) for i in range(batch_size)]\n",
    "    enc_x = [[ord(ch)-ord('A')+1 for ch in list(exp)] for exp in batched_examples]\n",
    "    y = [[o for o in reversed(e_idx)] for e_idx in enc_x]\n",
    "    dec_x = [[0]+e_idx[:-1] for e_idx in y]\n",
    "    return (batched_examples, tf.constant(enc_x, dtype=tf.int32), \n",
    "            tf.constant(dec_x, dtype=tf.int32), tf.constant(y, dtype=tf.int32))\n",
    "print(get_batch(2, 10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 建立sequence to sequence 模型\n",
    "\n",
    "完成两空，模型搭建以及单步解码逻辑"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class mySeq2SeqModel(keras.Model):\n",
    "    def __init__(self):\n",
    "        super(mySeq2SeqModel, self).__init__()\n",
    "        self.v_sz=27\n",
    "        self.hidden = 128\n",
    "        self.embed_layer = tf.keras.layers.Embedding(self.v_sz, 64, \n",
    "                                                    batch_input_shape=[None, None])\n",
    "        \n",
    "        self.encoder_cell = tf.keras.layers.SimpleRNNCell(self.hidden)\n",
    "        self.decoder_cell = tf.keras.layers.SimpleRNNCell(self.hidden)\n",
    "        \n",
    "        self.encoder = tf.keras.layers.RNN(self.encoder_cell, \n",
    "                                           return_sequences=True, return_state=True)\n",
    "        self.decoder = tf.keras.layers.RNN(self.decoder_cell, \n",
    "                                           return_sequences=True, return_state=True)\n",
    "        self.dense_attn = tf.keras.layers.Dense(self.hidden)\n",
    "        self.dense = tf.keras.layers.Dense(self.v_sz)\n",
    "        \n",
    "        \n",
    "    @tf.function\n",
    "    def call(self, enc_ids, dec_ids):\n",
    "        '''\n",
    "        todo\n",
    "        \n",
    "        完成带attention机制的 sequence2sequence 模型的搭建，模块已经在`__init__`函数中定义好，\n",
    "        用双线性attention，或者自己改一下`__init__`函数做加性attention\n",
    "        '''\n",
    "        return logits\n",
    "    \n",
    "    \n",
    "    @tf.function\n",
    "    def encode(self, enc_ids):\n",
    "        enc_emb = self.embed_layer(enc_ids) # shape(b_sz, len, emb_sz)\n",
    "        enc_out, enc_state = self.encoder(enc_emb)\n",
    "        return enc_out, [enc_out[:, -1, :], enc_state]\n",
    "    \n",
    "    def get_next_token(self, x, state, enc_out):\n",
    "        '''\n",
    "        shape(x) = [b_sz,] \n",
    "        '''\n",
    "    \n",
    "        '''\n",
    "        todo\n",
    "        参考sequence_reversal-exercise, 自己构建单步解码逻辑'''\n",
    "        return out, state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss函数以及训练逻辑"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def compute_loss(logits, labels):\n",
    "    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits=logits, labels=labels)\n",
    "    losses = tf.reduce_mean(losses)\n",
    "    return losses\n",
    "\n",
    "@tf.function\n",
    "def train_one_step(model, optimizer, enc_x, dec_x, y):\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(enc_x, dec_x)\n",
    "        loss = compute_loss(logits, y)\n",
    "\n",
    "    # compute gradient\n",
    "    grads = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "    return loss\n",
    "\n",
    "def train(model, optimizer, seqlen):\n",
    "    loss = 0.0\n",
    "    accuracy = 0.0\n",
    "    for step in range(2000):\n",
    "        batched_examples, enc_x, dec_x, y = get_batch(32, seqlen)\n",
    "        loss = train_one_step(model, optimizer, enc_x, dec_x, y)\n",
    "        if step % 500 == 0:\n",
    "            print('step', step, ': loss', loss.numpy())\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练迭代"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0 : loss 3.3069763\n",
      "step 500 : loss 1.3605067\n",
      "step 1000 : loss 0.2817158\n",
      "step 1500 : loss 0.105303034\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=46985, shape=(), dtype=float32, numpy=0.05765622>"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "optimizer = optimizers.Adam(0.0005)\n",
    "model = mySeq2SeqModel()\n",
    "train(model, optimizer, seqlen=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试模型逆置能力\n",
    "首先要先对输入的一个字符串进行encode，然后在用decoder解码出逆置的字符串\n",
    "\n",
    "测试阶段跟训练阶段的区别在于，在训练的时候decoder的输入是给定的，而在预测的时候我们需要一步步生成下一步的decoder的输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[False, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True]\n",
      "[('LGNVPEUAIFNLDIRQPMEY', 'YEMPQRIDLNFIAUEPVNGL'), ('ZFKNSOQGHINBCDJMOQXO', 'OXQOMJDCBNIHGQOSNKFZ'), ('ZXGDWLBRJMDSNYXIEMKJ', 'JKMEIXYNSDMJRBLWDGXZ'), ('KSEOFXPRMGUBHSJWIGYC', 'CYGIWJSHBUGMRPXFOESK'), ('HHMACLLSVMESLETLYJLZ', 'ZLJYLTELSEMVSLLCAMHH'), ('PVIABJTRCEQZFSJZLKQP', 'PQKLZJSFZQECRTJBAIVP'), ('PTUYDMQCQKZWBNCJRZBG', 'GBZRJCNBWZKQCQMDYUTP'), ('ZFZEWFPTWQFFSNNQLOFW', 'WFOLQNNSFFQWTPFWEZFZ'), ('NAHVTGYNHHEXOKMOEDRN', 'NRDEOMKOXEHHNYGTVHAN'), ('CTDDCDXRMIQDKCHODXVU', 'UVXDOHCKDQIMRXDCDDTL'), ('QWMRORZQQQTSJEHNVSZN', 'NZSVNHEJSTQQQZRORMWQ'), ('HTLTKWDCPXYCOREQGOUA', 'AUOGQEROCYXPCDWKTLTH'), ('DIQEGEZTNCFGNHHCXEDE', 'EDEXCHHNGFCNTZEGEQIJ'), ('HLGMERZQBOYYDQWMKKVE', 'EVKKMWQDYYOBQZREMGLH'), ('MCUCLWJZZVBMJFTNMAMU', 'UMAMNTFJMBVZZJWLCUCY'), ('UJIKEJYXYBBUSNIDOMBL', 'LBMODINSUBBYXYJEKIJU'), ('XPBYWUZGZJMRVMCVONLM', 'MLNOVCMVRMJZGZUWYBPX'), ('GDEKKHYSSGHPDRIZSZLY', 'YLZSZIRDPHGSSYHKKEDG'), ('RVPPCAFTWYOLOFDBDRGR', 'RGRDBDFOLOYWTFACPPVR'), ('SHZHNIHVWGCNCFQJQGQZ', 'ZQGQJQFCNCGWVHINHBHS'), ('NAQPBIBCNFGVUDJKPMGZ', 'ZGMPKJDUVGFNCBIBPQAN'), ('INGPKHNHAVQCGQZHPHCG', 'GCHPHZQGCQVAHNHKPGNI'), ('KWWCWEGNMIMJUMYJETJQ', 'QJTEJYMUJMIMNGEWCWWK'), ('TETRFQIESYAPMFRUYSZE', 'EZSYURFMPAYSEIQFRTET'), ('QSRNBFLVLALCFPSZQNMN', 'NMNQZSPFCLALVLFBNRSQ'), ('SSZQYNNEMFKXJGTELUTF', 'FTULETGJXKFMENNYQZSS'), ('MBWHGDXBIOYHLEULXADG', 'GDAXLUELHYOIBXDGHWBM'), ('TTTFGSHTCLNAAMUMYNFU', 'UFNYMUMAANLCTHSGFTTT'), ('IEGUFJYJAIOGLNZYTKZC', 'CZKTYZNLGOIAJYJFUGEI'), ('MHBPGXFEZNNZHMXMQYTM', 'MTYQMXMHZNNZEFXGPBHM'), ('LUMXHUIUXCCBRUXDVOOC', 'COOVDXURBCCXUIUHXMJL'), ('QJTLJGVXGRYIYXWTGTQS', 'SQTGTWXYIYRGXVGJLTJQ')]\n"
     ]
    }
   ],
   "source": [
    "def sequence_reversal():\n",
    "    def decode(init_state, steps, enc_out):\n",
    "        b_sz = tf.shape(init_state[0])[0]\n",
    "        cur_token = tf.zeros(shape=[b_sz], dtype=tf.int32)\n",
    "        state = init_state\n",
    "        collect = []\n",
    "        for i in range(steps):\n",
    "            cur_token, state = model.get_next_token(cur_token, state, enc_out)\n",
    "            collect.append(tf.expand_dims(cur_token, axis=-1))\n",
    "        out = tf.concat(collect, axis=-1).numpy()\n",
    "        out = [''.join([chr(idx+ord('A')-1) for idx in exp]) for exp in out]\n",
    "        return out\n",
    "    \n",
    "    batched_examples, enc_x, _, _ = get_batch(32, 20)\n",
    "    enc_out, state = model.encode(enc_x)\n",
    "    return decode(state, enc_x.get_shape()[-1], enc_out), batched_examples\n",
    "\n",
    "def is_reverse(seq, rev_seq):\n",
    "    rev_seq_rev = ''.join([i for i in reversed(list(rev_seq))])\n",
    "    if seq == rev_seq_rev:\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "print([is_reverse(*item) for item in list(zip(*sequence_reversal()))])\n",
    "print(list(zip(*sequence_reversal())))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
